import numpy as np
import utils
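
# This module assumes a companion `utils` module that provides the helpers
# used below. A minimal sketch of plausible definitions (assumptions for
# reference only; the actual utils implementation may differ):
#
#   def relu(x): return np.maximum(0, x)
#   def sigmoid(x): return 1.0 / (1.0 + np.exp(-x))
#   def softmax(x):
#       e = np.exp(x - x.max(axis=1, keepdims=True))  # shift for stability
#       return e / e.sum(axis=1, keepdims=True)
#   def do_normalize(x, m, v): return (x - m) / np.sqrt(v + 1e-8)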


# Feedforward Neural Network (FNN)
class FNN:
    # shape lists the layer widths, e.g. [input_dim, hidden_dim, class_num]
    def __init__(self, shape, activation='sigmoid') -> None:
        self._l = len(shape) - 1  # number of weight layers
        self._w = []  # weight
        self._b = []  # bias
        self._z = [None] * self._l  # pre-activations: x @ w + b
        self._a = [None] * self._l  # post-activations: f(x @ w + b)
        self._x = None  # input feature
        self._n = None  # batch size
        self._m = None  # mean for normalization
        self._v = None  # variance for normalization
        self._act = activation  # output-layer activation ('sigmoid' or 'softmax')
        # He initialization
        for ly in range(self._l):
            self._w.append(np.random.randn(shape[ly], shape[ly + 1]) * np.sqrt(2 / shape[ly]))
            self._b.append(np.zeros((1, shape[ly + 1])))
        assert self._l == len(self._w) == len(self._b)
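        # Example (hypothetical dimensions): FNN([784, 128, 10], activation='softmax')
        # builds a 784 -> 128 -> 10 net with ReLU hidden units and a softmax output.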

    def set_normalize(self, m, v):
        self._m = m
        self._v = v
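        # Typical usage (a sketch; per-feature statistics of the training set):
        #   net.set_normalize(x_train.mean(axis=0), x_train.var(axis=0))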

    def forward(self, _x):
        if self._m is None:
            self._x = _x
        else:
            self._x = utils.do_normalize(_x, self._m, self._v)
        self._n = self._x.shape[0]
        # hidden layers: linear -> ReLU; output layer: linear -> sigmoid/softmax
        self._z[0] = np.dot(self._x, self._w[0]) + self._b[0]
        for ly in range(1, self._l):
            self._a[ly - 1] = utils.relu(self._z[ly - 1])
            self._z[ly] = np.dot(self._a[ly - 1], self._w[ly]) + self._b[ly]
        if self._act == 'softmax':
            self._a[self._l - 1] = utils.softmax(self._z[self._l - 1])
        else:
            self._a[self._l - 1] = utils.sigmoid(self._z[self._l - 1])
        return self._a[self._l - 1]
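
    # Shape notes for forward() (a sketch): _x is (n, shape[0]); each
    # self._z[ly] / self._a[ly] is (n, shape[ly + 1]); the returned
    # prediction is (n, shape[-1]).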

    # with 'softmax' output we train against cross-entropy error (CEE);
    # with 'sigmoid' output we train against mean squared error (MSE)
    def backward(self, _y, lr=0.1):
        # softmax + CEE: the output-layer gradient simplifies to a - y
        dz = self._a[self._l - 1] - _y
        # sigmoid + MSE: chain in the sigmoid derivative a * (1 - a)
        if self._act != 'softmax':
            dz = dz * self._a[self._l - 1] * (1 - self._a[self._l - 1])
        for ly in reversed(range(1, self._l)):
            dw = 1. / self._n * np.dot(self._a[ly - 1].T, dz)
            db = 1. / self._n * np.sum(dz, axis=0)
            # propagate with the pre-update weights, then apply the step
            da = np.dot(dz, self._w[ly].T)
            self._w[ly] -= lr * dw
            self._b[ly] -= lr * db
            # ReLU derivative: zero the gradient where the activation was zero
            dz = da.copy()
            dz[self._a[ly - 1] <= 0] = 0
        # first layer: the gradient uses the (normalized) input features
        dw = 1. / self._n * np.dot(self._x.T, dz)
        db = 1. / self._n * np.sum(dz, axis=0)
        self._w[0] -= lr * dw
        self._b[0] -= lr * db
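
    # Why the output-layer gradients take those forms (a sketch of the
    # standard derivation; the loss is averaged over the batch by the 1/n factor):
    #   softmax + CEE: L = -sum_k y_k log a_k        =>  dL/dz = a - y
    #   sigmoid + MSE: L = 0.5 sum_k (a_k - y_k)^2   =>  dL/dz = (a - y) * a * (1 - a)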

    # save model weights to '<name>.txt': one whitespace-separated line per
    # weight matrix (row-major), followed by one line per bias vector
    def save(self, _name):
        with open(_name + '.txt', 'w') as file:
            for fp in self._w:
                for i in range(fp.shape[0]):
                    for j in range(fp.shape[1]):
                        file.write(str(fp[i, j]))
                        file.write(' ')
                file.write('\n')
            for fp in self._b:
                for i in range(fp.shape[0]):
                    for j in range(fp.shape[1]):
                        file.write(str(fp[i, j]))
                        file.write(' ')
                file.write('\n')

    # load model weights from a file previously written by save()
    def load(self, _path):
        with open(_path, 'r') as file:
            s = file.readlines()
        assert len(s) == len(self._w) * 2
        n = len(s) // 2
        for i in range(n):
            # first n lines hold the weight matrices, last n the biases
            w_vals = list(map(float, s[i].strip().split(' ')))
            self._w[i] = np.array(w_vals).reshape(self._w[i].shape)
            b_vals = list(map(float, s[i + n].strip().split(' ')))
            self._b[i] = np.array(b_vals).reshape(self._b[i].shape)
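

# A minimal usage sketch (assumes the `utils` helpers sketched at the top of
# this file exist); trains on random data purely to exercise the API:
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    net = FNN([4, 8, 3], activation='softmax')
    x = rng.standard_normal((32, 4))            # batch of 32 samples
    y = np.eye(3)[rng.integers(0, 3, size=32)]  # one-hot labels
    for _ in range(100):
        net.forward(x)
        net.backward(y, lr=0.1)
    net.save('demo')      # writes demo.txt
    net.load('demo.txt')  # round-trips the weights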
